{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d579aba-7439-42a6-aae4-cf1983095dee",
   "metadata": {},
   "source": [
    "# Example 06 - TorchSigWideband with YOLOv8 Detector\n",
    "This notebook showcases using the Wideband dataset to train a YOLOv8 model.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a026bd-f096-47f3-a262-48ab5defe23e",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab949825-7c48-44f2-852d-56567f5952e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Packages Imports for Training\n",
    "from torchsig.utils.yolo_train import *\n",
    "from datetime import datetime\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bff6d5-4b2d-4db2-97a0-f45843e7cc60",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Package Imports for Testing/Inference\n",
    "from torchsig.datasets.datamodules import WidebandDataModule\n",
    "from torchsig.datasets.signal_classes import torchsig_signals\n",
    "from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity\n",
    "from torchsig.transforms.target_transforms import DescToBBoxFamilyDict\n",
    "from ultralytics import YOLO\n",
    "import torch\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efea64b7-1d22-41bd-81cc-5c39a74c3bdd",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Check or Generate the Wideband Dataset\n",
    "To generate the TorchSigWideband dataset, several parameters are given to the imported `WidebandDataModule` class. These paramters are:\n",
    "- `root` ~ A string to specify the root directory of where to generate and/or read an existing TorchSigWideband dataset\n",
    "- `train` ~ A boolean to specify if the TorchSigWideband dataset should be the training (True) or validation (False) sets\n",
    "- `qa` - A boolean to specify whether to generate a small subset of Wideband (True), or the full dataset (False), default is True\n",
    "- `impaired` ~ A boolean to specify if the TorchSigWideband dataset should be the clean version or the impaired version\n",
    "- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline. Note: these transforms are not called during the dataset generation. The static saved dataset will always be in IQ format. The transform is only called when retrieving data examples.\n",
    "- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline. Note: these target transforms are not called during the dataset generation. The static saved dataset will always be saved as tuples in the LMDB dataset. The target transform is only called when retrieving data examples.\n",
    "\n",
    "A combination of the `train` and the `impaired` booleans determines which of the four (4) distinct TorchSigWideband datasets will be instantiated:\n",
    "| `impaired` | `qa` | Result |\n",
    "| ---------- | ---- | ------- |\n",
    "| `False` | `False` | Clean datasets of train=250k examples and val=25k examples |\n",
    "| `False` | `True` | Clean datasets of train=250 examples and val=250 examples |\n",
    "| `True` | `False` | Impaired datasets of train=250k examples and val=25k examples |\n",
    "| `True` | `True` | Impaired datasets of train=250 examples and val=250 examples |\n",
    "\n",
    "The final option of the impaired validation set is the dataset to be used when reporting any results with the official TorchSigWideband dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6e8bfb-92f3-4615-afac-7568350a2461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate TorchSigWideband DataModule\n",
    "root = \"./datasets/wideband\"\n",
    "impaired = True\n",
    "qa = True\n",
    "fft_size = 512\n",
    "num_classes = len(torchsig_signals.class_list)\n",
    "batch_size = 1\n",
    "\n",
    "transform = Compose([    \n",
    "])\n",
    "\n",
    "target_transform = Compose([\n",
    "    DescToBBoxFamilyDict()\n",
    "])\n",
    "\n",
    "datamodule = WidebandDataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2964745c-4fda-4cdc-93ce-6ac67a5e1da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_train = datamodule.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(wideband_train))\n",
    "data, label = wideband_train[idx]\n",
    "print(\"Dataset length: {}\".format(len(wideband_train)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label: {}\".format(label))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2938efda-7b7f-42e9-8bcf-022ebcaf2d32",
   "metadata": {},
   "source": [
    "## Prepare YOLO trainer and Model\n",
    "Next, the datasets are rewritten to disk that is Ultralytics YOLO compatible. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more. \n",
    "\n",
    "Additionally, create a yaml file for dataset configuration. See [Ultralytics: Train Custom Data - Create dataset.yaml](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#21-create-datasetyaml)\n",
    "\n",
    "Download desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8n.pt`\n",
    "\n",
    "---"
   ]
  },
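  {
   "cell_type": "markdown",
   "id": "yolo-label-format-sketch",
   "metadata": {},
   "source": [
    "For reference, Ultralytics detection labels are plain-text files with one line per object, `class x_center y_center width height`, where the four box values are normalized to the range 0 to 1. The block below is a minimal sketch of that conversion, not the TorchSig implementation: the helper `to_yolo_line` and its argument names are hypothetical, and it assumes time runs along the image x-axis and frequency along the y-axis.\n",
    "\n",
    "```python\n",
    "def to_yolo_line(class_id, t_start, t_stop, f_low, f_high):\n",
    "    # Hypothetical helper: all bounds are assumed to be normalized to [0, 1].\n",
    "    x_center = (t_start + t_stop) / 2\n",
    "    y_center = (f_low + f_high) / 2\n",
    "    width = t_stop - t_start\n",
    "    height = f_high - f_low\n",
    "    # One label line per object: class id followed by the normalized box.\n",
    "    return f'{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}'\n",
    "\n",
    "\n",
    "# A signal covering the middle half of the capture in time and the\n",
    "# band between 0.5 and 0.75 in normalized frequency.\n",
    "print(to_yolo_line(4, 0.25, 0.75, 0.5, 0.75))\n",
    "# prints: 4 0.500000 0.625000 0.500000 0.250000\n",
    "```\n",
    "\n",
    "Each image gets one such text file with the same base name, which is what lets the YOLO trainer pair spectrogram images with their boxes."
   ]
  },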
  {
   "cell_type": "markdown",
   "id": "f8f5ad86-bb23-4396-8cd3-3b1c1c9a32b2",
   "metadata": {},
   "source": [
    "### Explanation of the `overrides` Dictionary\n",
    "\n",
    "The `overrides` dictionary is used to customize the settings for the Ultralytics YOLO trainer by specifying specific values that override the default configurations. The dictionary is imported from `wbdata.yaml`. However, you can customize in the notebook. See [Ultralytics Train Settings](https://docs.ultralytics.com/modes/train/#train-settings) to learn more.\n",
    "\n",
    "Example:\n",
    "\n",
    "```python\n",
    "overrides = {'model': 'yolov8n.pt', 'epochs': 100, 'data': 'wbdata.yaml', 'device': 0, 'imgsz': 512, 'single_cls': True}\n",
    "```\n",
    "A .yaml is necessary for training. Look at `06_yolo.yaml` in the examples directory. It will contain the path to your torchsig data.\n",
    "\n",
    "\n",
    "### Dataset Location Warning\n",
    "\n",
    "There must exist a datasets directory at `/path/to/torchsig/datasets`.\n",
    "\n",
    "This example assumes that you have generated `train` and `val` lmdb wideband datasets at `./datasets/wideband/`\n",
    "\n",
    "You can also specify an absolute path to your dataset in `06_yolo.yaml`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "783e7f33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define dataset variables for yaml file\n",
    "config_name = \"06_yolo.yaml\"\n",
    "class_list = [\"ask\", \"fsk\", \"ofdm\", \"pam\", \"psk\", \"qam\"]\n",
    "classes = {v: k for v, k in enumerate(class_list)}\n",
    "num_classes = len(classes)\n",
    "yolo_root = \"./wideband/\" # train/val images (relative to './datasets``\n",
    "\n",
    "# define overrides\n",
    "# Note: You can change use of GPU(s) or CPU by overriding the device\n",
    "# GPU: device=0 or device=0,1\n",
    "# CPU: device=\"cpu\"\n",
    "overrides = dict(\n",
    "    model = \"yolov8n.pt\",\n",
    "    project = \"yolo\",\n",
    "    name = \"06_example\",\n",
    "    epochs = 10,\n",
    "    imgsz = 512,\n",
    "    data = config_name,\n",
    "    device = 0 if torch.cuda.is_available() else \"cpu\",\n",
    "    single_cls = True,\n",
    "    batch = 32,\n",
    "    workers = 8\n",
    "\n",
    ")\n",
    "\n",
    "# create yaml file for trainer\n",
    "yolo_config = dict(\n",
    "    overrides = overrides,\n",
    "    train = yolo_root,\n",
    "    val = yolo_root,\n",
    "    nc = num_classes,\n",
    "    names = classes\n",
    ")\n",
    "\n",
    "with open(config_name, 'w+') as file:\n",
    "    yaml.dump(yolo_config, file, default_flow_style=False)\n",
    "\n",
    "print(f\"Creating experiment -> {overrides['name']}\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a6f0586-c1d5-4230-b9a4-4f9e6413c635",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Yolo_Trainer(overrides=overrides)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22fa62fc-01d9-46b7-9cd2-38643f54d1de",
   "metadata": {},
   "source": [
    "## Train\n",
    "Train YOLO. See [Ultralytics Train](https://docs.ultralytics.com/modes/train/#train-settings) for training hyperparameter options.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7e5f57-3b5b-4074-b87c-6e18c7b973af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad0c459-2367-4f7d-89ae-5f4cfd8f545f",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "Check model performance from training. From here, you can use the trained model to test on prepared data (numpy image arrays of spectrograms)\n",
    "\n",
    "Will load example from Torchsig\n",
    "\n",
    "model_path is path to best.pt from your training session. Path is printed at the end of training.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c45fa7bf-fa7d-4896-9022-4e608d93e5a4",
   "metadata": {},
   "source": [
    "## Generate and Instantiate Wideband Test Dataset\n",
    "After generating the Wideband dataset (see `03_example_widebandsig_dataset.ipynb`), we can instantiate it with the needed transforms. Change `root` to test dataset path.\n",
    "\n",
    "---"
   ]
  },
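  {
   "cell_type": "markdown",
   "id": "spectrogram-shape-sketch",
   "metadata": {},
   "source": [
    "The `Spectrogram` settings used in the next cell determine the size of the image passed to YOLO. The block below is a back-of-the-envelope sketch of that arithmetic, not TorchSig code. It assumes each example holds 512 x 512 = 262,144 complex IQ samples (an assumption about the dataset, not something stated in this notebook), that segments are not padded, and that complex input yields a two-sided spectrum with `nfft` frequency bins. The helper `spectrogram_shape` is hypothetical.\n",
    "\n",
    "```python\n",
    "def spectrogram_shape(num_samples, nperseg, noverlap, nfft):\n",
    "    # Hypothetical helper: returns (frequency bins, time bins).\n",
    "    hop = nperseg - noverlap  # samples advanced per segment\n",
    "    num_time_bins = (num_samples - noverlap) // hop\n",
    "    num_freq_bins = nfft  # two-sided spectrum for complex IQ input\n",
    "    return num_freq_bins, num_time_bins\n",
    "\n",
    "\n",
    "# Assumed example length of 512 * 512 IQ samples\n",
    "print(spectrogram_shape(512 * 512, nperseg=512, noverlap=0, nfft=512))\n",
    "# prints: (512, 512)\n",
    "```\n",
    "\n",
    "Under these assumptions the spectrogram is 512 x 512, which matches the `imgsz = 512` value used for training and prediction."
   ]
  },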
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "694c33f1-d718-40ae-861d-26c3431a6f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_path = './datasets/wideband_test' #Should differ from your training dataset\n",
    "\n",
    "transform = Compose([\n",
    "    Spectrogram(nperseg=512, noverlap=0, nfft=512, mode='psd'),\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "    SpectrogramImage(), \n",
    "    ])\n",
    "\n",
    "test_data = WidebandDataModule(\n",
    "    root=test_path,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=None,\n",
    "    batch_size=batch_size\n",
    ")\n",
    "\n",
    "test_data.prepare_data()\n",
    "test_data.setup(\"fit\")\n",
    "\n",
    "wideband_test = test_data.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(wideband_test))\n",
    "data, label = wideband_test[idx]\n",
    "print(\"Dataset length: {}\".format(len(wideband_test)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "\n",
    "samples = []\n",
    "labels = []\n",
    "for i in range(10):\n",
    "    idx = np.random.randint(len(wideband_test))\n",
    "    sample, label = wideband_test[idx]\n",
    "    lb = [l['class_name'] for l in label]\n",
    "    samples.append(sample)\n",
    "    labels.append(lb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de9947fc-9636-415c-af50-26d7e5d95953",
   "metadata": {},
   "source": [
    "### Load model \n",
    "The model path is printed after training. Use the best.pt weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2bb8735-ce6c-482b-9675-302f6848b0ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = YOLO(trainer.best)"
   ]
  },
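  {
   "cell_type": "markdown",
   "id": "best-weights-path-sketch",
   "metadata": {},
   "source": [
    "If the `trainer` object is no longer in memory, the weights path can usually be rebuilt from the `project` and `name` overrides. The block below is a minimal sketch that assumes the default Ultralytics run layout of `<project>/<name>/weights/best.pt`; the helper `expected_best_weights` is hypothetical. Note that Ultralytics may append a numeric suffix to `name` when a run directory already exists, so check the path printed at the end of training.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def expected_best_weights(project, name):\n",
    "    # Hypothetical helper: assumes the layout <project>/<name>/weights/best.pt\n",
    "    return Path(project) / name / 'weights' / 'best.pt'\n",
    "\n",
    "\n",
    "# Values taken from the overrides used in this notebook\n",
    "best_path = expected_best_weights('yolo', '06_example')\n",
    "print(best_path, 'exists:', best_path.exists())\n",
    "```\n",
    "\n",
    "If the file exists, the resulting path can be passed to `YOLO(...)` in place of `trainer.best`."
   ]
  },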
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec6d1966-31a8-4ab8-9ce9-8557d5e3c956",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference will be saved to path printed after predict. \n",
    "results = model.predict(samples, save=True, imgsz=512, conf=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a07a684",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef3b004a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot prediction results\n",
    "rows = 3\n",
    "cols = 3\n",
    "fig = plt.figure(figsize=(15, 15)) \n",
    "results_dir = results[0].save_dir\n",
    "\n",
    "for y, result in enumerate(results[:9]):\n",
    "    imgpath = os.path.join(results_dir, \"image\" + str(y) + \".jpg\")\n",
    "    fig.add_subplot(rows, cols, y + 1) \n",
    "    img = cv2.imread(imgpath)\n",
    "    plt.imshow(img)\n",
    "    plt.title(str(labels[y]), fontsize='small', loc='left')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
